In [1]:
%matplotlib notebook
import matplotlib.pyplot as plt
import numpy as np
In [2]:
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=plt.figaspect(0.5))
ax1.plot([-10, -5, 0, 5, 10, 15], [-1.2, 2, 3.5, -0.3, -4, 1])
ax2.scatter([-10, -5, 0, 5, 10, 15], [-1.2, 2, 3.5, -0.3, -4, 1])
plt.show()
ax.margins(...)
If you'd like to add a bit of "padding" to a plot, ax.margins(<some_small_fraction>) is a very handy way to do so. Instead of rounding each axis out to "even-ish" numbers (the classic behavior of matplotlib before 2.0), margins makes matplotlib calculate the limits of each axis from the range of the data plus a fractional amount of padding on each end.
As an example: (With the pre-2.0 defaults, the ranges for the scatter example actually shrink slightly in this case. Since matplotlib 2.0 the default margin is already 0.05, so ax2.margins(0.05) leaves that plot unchanged.)
In [3]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=plt.figaspect(0.5))
ax1.plot([-10, -5, 0, 5, 10, 15], [-1.2, 2, 3.5, -0.3, -4, 1])
ax2.scatter([-10, -5, 0, 5, 10, 15], [-1.2, 2, 3.5, -0.3, -4, 1])
ax1.margins(x=0.0, y=0.1) # 10% padding in the y-direction only
ax2.margins(0.05) # 5% padding in all directions
plt.show()
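The sketch below shows what margins does to the limits. It plots made-up data spanning 0 to 10 on both axes and asks for a 10% margin. This assumes the default "data" autolimit mode of matplotlib 2.0 and later. Under that assumption, each end should be padded by 10% of the data range, which is one unit.

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot([0, 10], [0, 10])   # made-up data: the range is 10 units in x and in y
ax.margins(0.1)             # pad each end by 10% of the data range

print(ax.margins())                  # the current (xmargin, ymargin)
print(ax.get_xlim(), ax.get_ylim())  # expected to be about (-1, 11) on both axes
plt.show()
```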
ax.axis(...)
The ax.axis(...) method is a convenient way of controlling the axes limits and enabling/disabling autoscaling.
If you ever need to get all of the current plot limits, calling ax.axis() with no arguments will return the xmin/max/etc:
xmin, xmax, ymin, ymax = ax.axis()
If you'd like to manually set all of the x/y limits at once, you can use ax.axis for this as well (note that we're calling it with a single argument that's a sequence, not 4 individual arguments):
ax.axis([xmin, xmax, ymin, ymax])
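The sketch below tries the two calls above together; the limit values are arbitrary. Passing ax.axis a single four-element list sets all of the limits. Calling ax.axis() with no arguments reads them back as a tuple.

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot([-10, -5, 0, 5, 10, 15], [-1.2, 2, 3.5, -0.3, -4, 1])

# Set all four limits at once: one list argument, not four separate ones
ax.axis([-20, 20, -5, 5])

# Read them back in the order xmin, xmax, ymin, ymax
xmin, xmax, ymin, ymax = ax.axis()
print(xmin, xmax, ymin, ymax)
plt.show()
```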
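However, you'll probably use axis mostly with either the "tight" or "equal" options. There are other options as well; see the documentation for full details. In a nutshell, though:
- "tight": set the limits just large enough to show all of the data, then disable further autoscaling.
- "equal": set equal scaling, so that one data unit covers the same distance in x as in y (circles look circular). This is done by changing the axis limits, not by making the axes box square.
And as an example: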
In [4]:
fig, axes = plt.subplots(nrows=3)
for ax in axes:
    ax.plot([-10, -5, 0, 5, 10, 15], [-1.2, 2, 3.5, -0.3, -4, 1])
axes[0].set_title('Normal Autoscaling', y=0.7, x=0.8)
axes[1].set_title('ax.axis("tight")', y=0.7, x=0.8)
axes[1].axis('tight')
axes[2].set_title('ax.axis("equal")', y=0.7, x=0.8)
axes[2].axis('equal')
plt.show()
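To change just one limit, pass only that keyword to ax.set_xlim(...) (left or right) or ax.set_ylim(...) (bottom or top). The other limit keeps its current value. Do this after plotting is done; compare the two cells below.
In [5]: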
# Good -- setting limits after plotting is done
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=plt.figaspect(0.5))
ax1.plot([-10, -5, 0, 5, 10, 15], [-1.2, 2, 3.5, -0.3, -4, 1])
ax2.scatter([-10, -5, 0, 5, 10, 15], [-1.2, 2, 3.5, -0.3, -4, 1])
ax1.set_ylim(bottom=-10)
ax2.set_xlim(right=25)
plt.show()
In [6]:
# Bad -- Setting limits before plotting is done
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=plt.figaspect(0.5))
ax1.set_ylim(bottom=-10)
ax2.set_xlim(right=25)
ax1.plot([-10, -5, 0, 5, 10, 15], [-1.2, 2, 3.5, -0.3, -4, 1])
ax2.scatter([-10, -5, 0, 5, 10, 15], [-1.2, 2, 3.5, -0.3, -4, 1])
plt.show()
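Why does the order matter? By default, set_ylim(...) also switches autoscaling off for that axis, because its auto keyword defaults to False. The same applies to set_xlim(...). The minimal sketch below checks this with get_autoscaley_on() and reuses the data from the cells above. Once autoscaling is off, plotting afterwards can no longer raise the top limit.

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
print(ax.get_autoscaley_on())   # True: the y-limits will follow the data

# set_ylim defaults to auto=False, so this also switches y-autoscaling off
ax.set_ylim(bottom=-10)
print(ax.get_autoscaley_on())   # False

ax.plot([-10, -5, 0, 5, 10, 15], [-1.2, 2, 3.5, -0.3, -4, 1])
# The top limit is still the pre-plot default, so the peak at 3.5 is cut off
print(ax.get_ylim())
plt.show()
```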
As you've seen in some of the examples so far, the X and Y axes can be labeled, and so can the subplot itself, via its title.
You can also label each line, set of points, bar, etc. that you plot. Giving your plots a label lets the legend build itself automatically.
In [7]:
fig, ax = plt.subplots()
ax.plot([1, 2, 3, 4], [10, 20, 25, 30], label='Philadelphia')
ax.plot([1, 2, 3, 4], [30, 23, 13, 4], label='Boston')
ax.set(ylabel='Temperature (deg C)', xlabel='Time', title='A tale of two cities')
ax.legend(loc="upper left")
plt.show()
In matplotlib 2.0 and later, an Axes legend defaults to loc="best", which places it where it overlaps the plotted data as little as possible. Older versions defaulted to the upper right corner. You can pick a location yourself with the loc kwarg, as in the example above, or ask for automatic placement explicitly:
ax.legend(loc="best")
Also, if you happen to be plotting something that you do not want to appear in the legend, just set the label to "_nolegend_". In general, any label that starts with an underscore is left out of the legend.
In [8]:
fig, ax = plt.subplots(1, 1)
ax.bar([1, 2, 3, 4], [10, 20, 25, 30], label="Foobar", align='center', color='lightblue')
ax.plot([1, 2, 3, 4], [10, 20, 25, 30], label="_nolegend_", marker='o', color='darkred')
ax.legend(loc='best')
plt.show()
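You can check which artists will end up in a legend before drawing it. ax.get_legend_handles_labels() returns the handles and labels that ax.legend() would use. In this small sketch the data and the labels 'squares' and '_cubes' are arbitrary. Anything whose label starts with an underscore is skipped, and so is a line with no label at all.

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot([1, 2, 3], [1, 4, 9], label='squares')
ax.plot([1, 2, 3], [1, 8, 27], label='_cubes')  # leading underscore: left out
ax.plot([1, 2, 3], [2, 2, 2])                   # no label at all: also left out

handles, labels = ax.get_legend_handles_labels()
print(labels)   # ['squares']
ax.legend(handles, labels)
plt.show()
```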
In [9]:
# %load exercises/4.1-legends_and_scaling.py
import numpy as np
import matplotlib.pyplot as plt
# Try to reproduce the figure shown in images/exercise_4-1.png
# Here's the data and colors used.
t = np.linspace(0, 2 * np.pi, 150)
x1, y1 = np.cos(t), np.sin(t)
x2, y2 = 2 * x1, 2 * y1
colors = ['darkred', 'darkgreen']
# Try to plot the two circles, scale the axes as shown and add a legend
# Hint: it's easiest to combine `ax.axis(...)` and `ax.margins(...)` to scale
# the axes
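For reference, here is one possible attempt at the exercise, offered only as a sketch. We can't be sure it matches images/exercise_4-1.png exactly, and the legend labels 'Inner' and 'Outer' are placeholders. It follows the hint by calling ax.axis('equal') and then ax.margins(...).

```python
import numpy as np
import matplotlib.pyplot as plt

t = np.linspace(0, 2 * np.pi, 150)
x1, y1 = np.cos(t), np.sin(t)
x2, y2 = 2 * x1, 2 * y1
colors = ['darkred', 'darkgreen']

fig, ax = plt.subplots()
# 'Inner' and 'Outer' are placeholder legend labels
ax.plot(x1, y1, color=colors[0], label='Inner')
ax.plot(x2, y2, color=colors[1], label='Outer')
ax.legend()
ax.axis('equal')   # equal scaling so the circles look circular
ax.margins(0.05)   # then a little padding around the data
plt.show()
```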
One key thing we haven't talked about yet is all of the annotation on the outside of the axes, the borders of the axes, and how to adjust the amount of space around the axes. We won't go over every detail, but this next section should give you a reasonable working knowledge of how to configure what happens around the edges of your axes.
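This is a constant source of confusion:
- A tick is the location on an Axis where a tick label appears.
- A tick line is the short line drawn to mark that location.
- A tick label is the text displayed at that tick.
- A Ticker automatically determines the ticks for an Axis and formats the tick labels.
tick_params() is often used to configure the appearance of ticks, tick lines and tick labels (direction, length, color, label size, etc.).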
In [10]:
fig, ax = plt.subplots()
ax.plot([1, 2, 3, 4], [10, 20, 25, 30])
# Manually set ticks and tick labels *on the x-axis* (note ax.xaxis.set, not ax.set!)
ax.xaxis.set(ticks=range(1, 5), ticklabels=[3, 100, -12, "foo"])
# Make the y-ticks a bit longer and go both in and out...
ax.tick_params(axis='y', direction='inout', length=10)
plt.show()
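The cell above fixes the ticks by hand. To let matplotlib compute them instead, the matplotlib.ticker module provides locators and formatters. A locator chooses the tick positions, and a formatter turns tick values into label text. The sketch below uses MultipleLocator and FormatStrFormatter. The spacing of 0.5 and the '%.1f' format are arbitrary choices.

```python
import matplotlib.pyplot as plt
from matplotlib.ticker import FormatStrFormatter, MultipleLocator

fig, ax = plt.subplots()
ax.plot([1, 2, 3, 4], [10, 20, 25, 30])

# Locator: put a major x tick at every multiple of 0.5
ax.xaxis.set_major_locator(MultipleLocator(0.5))
# Formatter: show every x tick label with one decimal place
ax.xaxis.set_major_formatter(FormatStrFormatter('%.1f'))
# Tick-line appearance is still controlled separately with tick_params
ax.tick_params(axis='x', direction='out', length=6)
plt.show()
```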
A commonly asked question is "How do I plot non-numerical categories?"
The easiest way to do this is to "fake" the x-values and then change the tick labels to reflect the category.
For example:
In [11]:
data = [('apples', 2), ('oranges', 3), ('peaches', 1)]
fruit, value = zip(*data)
fig, ax = plt.subplots()
x = np.arange(len(fruit))
ax.bar(x, value, align='center', color='gray')
ax.set(xticks=x, xticklabels=fruit)
plt.show()
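Depending on your matplotlib version, you may not need to fake the x-values at all. Since matplotlib 2.1, a list of strings can be passed directly, and each distinct category is mapped to an integer position. Here is a minimal sketch with the same fruit data:

```python
import matplotlib.pyplot as plt

data = [('apples', 2), ('oranges', 3), ('peaches', 1)]
fruit, value = zip(*data)

fig, ax = plt.subplots()
# Categories are placed at x = 0, 1, 2, ... in order of first appearance
ax.bar(list(fruit), list(value), align='center', color='gray')
plt.show()
```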
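The spacing between the subplots can be adjusted using fig.subplots_adjust(). The left, right, bottom and top arguments set the edges of the subplot area as fractions of the figure width or height. The wspace and hspace arguments set the gaps between subplots as fractions of the average axes width or height. Play around with the example below to see how the different arguments affect the spacing.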
In [12]:
fig, axes = plt.subplots(2, 2, figsize=(9, 9))
fig.subplots_adjust(wspace=0.5, hspace=0.3,
left=0.125, right=0.9,
top=0.9, bottom=0.1)
plt.show()
A common "gotcha" is that the labels are not automatically adjusted to avoid overlapping those of another subplot. By default, matplotlib does not run any layout engine; minimizing the amount of "magic" that matplotlib performs is a deliberate design decision. LaTeX users will be quite familiar with the frustration that figure placement can cause in their documents.
That said, there have been some efforts to develop tools that help address the most common complaints. The "Tight Layout" feature, when invoked, will attempt to resize margins and subplots so that nothing overlaps. Newer matplotlib versions also offer "constrained layout" (e.g. plt.subplots(..., constrained_layout=True)), which adjusts the layout automatically each time the figure is drawn.
If you have multiple subplots and want to avoid overlapping titles/axis labels/etc., fig.tight_layout() is a great way to do so:
In [13]:
def example_plot(ax):
    ax.plot([1, 2])
    ax.set_xlabel('x-label', fontsize=16)
    ax.set_ylabel('y-label', fontsize=8)
    ax.set_title('Title', fontsize=24)
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(nrows=2, ncols=2)
example_plot(ax1)
example_plot(ax2)
example_plot(ax3)
example_plot(ax4)
# Try enabling fig.tight_layout to compare...
#fig.tight_layout()
plt.show()
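fig.tight_layout() also accepts padding arguments. The sketch below redefines example_plot from above so the cell runs on its own. Here pad controls the space around the figure edge, and h_pad/w_pad control the space between neighboring subplots. All three are given as a fraction of the font size, and the particular values are arbitrary.

```python
import matplotlib.pyplot as plt

def example_plot(ax):
    ax.plot([1, 2])
    ax.set_xlabel('x-label', fontsize=16)
    ax.set_ylabel('y-label', fontsize=8)
    ax.set_title('Title', fontsize=24)

fig, axes = plt.subplots(nrows=2, ncols=2)
for ax in axes.flat:
    example_plot(ax)

# pad: space around the figure edge; h_pad/w_pad: space between subplots.
# All three are fractions of the font size.
fig.tight_layout(pad=0.5, h_pad=2.0, w_pad=1.0)
plt.show()
```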
Under the hood, matplotlib utilizes GridSpec to lay out the subplots. While plt.subplots() is fine for simple cases, sometimes you will need more advanced subplot layouts. In such cases, you should use GridSpec directly. GridSpec is outside the scope of this course, but it is handy to know that it exists; the matplotlib documentation includes a guide on how to use it.
There will be times when you want the x axis and/or the y axis of your subplots to be "shared". Sharing an axis ties that axis together across subplots, so any change in one of the shared axes changes all of them. This works very nicely with autoscaling arbitrary datasets that may have overlapping domains. Furthermore, when you interact with the plots (panning and zooming), all of the shared axes pan and zoom together.
In [14]:
fig, (ax1, ax2) = plt.subplots(1, 2, sharex=True, sharey=True)
ax1.plot([1, 2, 3, 4], [1, 2, 3, 4])
ax2.plot([3, 4, 5, 6], [6, 5, 4, 3])
plt.show()
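You can check that the axes really are tied together. In this sketch, which uses the same data as above, changing the x-limits of ax1 should update ax2 as well.

```python
import matplotlib.pyplot as plt

fig, (ax1, ax2) = plt.subplots(1, 2, sharex=True, sharey=True)
ax1.plot([1, 2, 3, 4], [1, 2, 3, 4])
ax2.plot([3, 4, 5, 6], [6, 5, 4, 3])

# Changing the limits on one Axes updates every Axes it is shared with
ax1.set_xlim(0, 10)
print(ax2.get_xlim())   # (0.0, 10.0)
plt.show()
```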
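ax.twinx() creates a second Axes that shares the x-axis with the first one but has its own, independent y-axis on the right-hand side. This is handy for overlaying two datasets whose values are on very different scales.
In [15]: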
fig, ax1 = plt.subplots(1, 1)
ax1.plot([1, 2, 3, 4], [1, 2, 3, 4])
ax2 = ax1.twinx()
ax2.scatter([1, 2, 3, 4], [60, 50, 40, 30])
ax1.set(xlabel='X', ylabel='First scale')
ax2.set(ylabel='Other scale')
plt.show()
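One practical wrinkle with twinx is that each Axes only keeps track of its own artists, so ax1.legend() alone would list only the line. A common workaround is to collect the handles and labels from both Axes with get_legend_handles_labels(). The combined lists then go to a single legend call. This is a sketch, and the labels and color are made up.

```python
import matplotlib.pyplot as plt

fig, ax1 = plt.subplots()
ax1.plot([1, 2, 3, 4], [1, 2, 3, 4], label='first dataset')
ax2 = ax1.twinx()   # shares the x-axis, adds a new y-axis on the right
ax2.scatter([1, 2, 3, 4], [60, 50, 40, 30], color='darkred', label='second dataset')

# Gather the legend entries from both Axes and draw one combined legend
h1, l1 = ax1.get_legend_handles_labels()
h2, l2 = ax2.get_legend_handles_labels()
ax1.legend(h1 + h2, l1 + l2, loc='upper center')
plt.show()
```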
Spines are the axis lines for a plot. Each plot can have four spines: "top", "bottom", "left" and "right". By default, they are set so that they frame the plot, but they can be individually positioned and configured via the set_position() method of the spine. Here are some different configurations.
In [16]:
fig, ax = plt.subplots()
ax.plot([-2, 2, 3, 4], [-10, 20, 25, 5])
ax.spines['top'].set_visible(False)
ax.xaxis.set_ticks_position('bottom') # no ticklines at the top
ax.spines['right'].set_visible(False)
ax.yaxis.set_ticks_position('left') # no ticklines on the right
# "outward"
# Move the two remaining spines "out" away from the plot by 10 points
# ax.spines['bottom'].set_position(('outward', 10))
# ax.spines['left'].set_position(('outward', 10))
# "data"
# Have the spines stay intersected at (0,0)
# ax.spines['bottom'].set_position(('data', 0))
# ax.spines['left'].set_position(('data', 0))
# "axes"
# Have the two remaining spines placed at a fraction of the axes
ax.spines['bottom'].set_position(('axes', 0.5))
ax.spines['left'].set_position(('axes', 0.5))
plt.show()
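Spine.set_position also accepts two shorthand strings: 'zero' is equivalent to ('data', 0.0), and 'center' is equivalent to ('axes', 0.5). This sketch reuses the plot above and uses 'zero', so both remaining spines cross at the data origin.

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot([-2, 2, 3, 4], [-10, 20, 25, 5])

# Hide the top and right spines, and keep tick lines on the bottom/left only
for side in ('top', 'right'):
    ax.spines[side].set_visible(False)
ax.xaxis.set_ticks_position('bottom')
ax.yaxis.set_ticks_position('left')

# 'zero' is shorthand for ('data', 0.0): both spines pass through the origin
ax.spines['bottom'].set_position('zero')
ax.spines['left'].set_position('zero')
plt.show()
```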
This one is a bit trickier. Once again, try to reproduce the figure (shown in images/exercise_4.2.png).
A few key hints: The two subplots have no vertical space between them (this means that the hspace is 0). Note that the bottom spine is at 0 in data coordinates and the tick lines are missing from the right and top sides.
Because you're going to be doing a lot of the same things to both subplots, you might consider writing a function that takes an Axes object and applies the spine changes, etc., to it. That avoids repetitive code.
In [66]:
# %load exercises/4.2-spines_ticks_and_subplot_spacing.py
import matplotlib.pyplot as plt
import numpy as np
# Try to reproduce the figure shown in images/exercise_4.2.png
# This one is a bit trickier!
# Here's the data...
data = [('dogs', 4, 4), ('frogs', -3, 1), ('cats', 1, 5), ('goldfish', -2, 2)]
animals, friendliness, popularity = zip(*data)
x = np.arange(len(animals))  # "fake" numeric x-positions for the categories

def format_axes(ax):
    """Apply the shared tick and spine styling to one Axes."""
    ax.set(xticks=x, xticklabels=animals)
    ax.xaxis.set_ticks_position('bottom')     # no tick lines on the top
    ax.yaxis.set_ticks_position('left')       # no tick lines on the right
    ax.spines['bottom'].set_position('zero')  # bottom spine at y=0 in data coords
    ax.axhline(y=0, color='black')

fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1)
fig.subplots_adjust(hspace=0)  # no vertical space between the subplots
ax1.bar(x, friendliness, align='center', color='gray')
ax2.bar(x, popularity, align='center', color='gray')
# Pad the y-limits to 108% of the data extremes
ax1.set(ylabel='Friendliness', ylim=(1.08 * min(friendliness), 1.08 * max(friendliness)))
ax2.set(ylabel='Popularity', ylim=(0, 1.08 * max(popularity)))
for ax in (ax1, ax2):
    format_axes(ax)
plt.show()